package cn.apotato.modules.mybatis.plus.handler;

import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeHandler;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.URL;
import java.sql.*;
import java.time.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * MyBatis type handler that maps a two-dimensional Java array ({@code T[][]}) to a SQL
 * {@code ARRAY} parameter and converts SQL {@code ARRAY} results back into {@code T[][]}.
 *
 * <p>Subclasses supply the element type through {@link #getType()}; the SQL element type
 * name handed to {@link Connection#createArrayOf(String, Object[])} is derived from it by
 * {@link #resolveTypeName(Class)}.
 *
 * @param <T> element type of the two-dimensional array
 * @author Hu Xiaopeng
 * &#064;date  2023/04/17
 */
public abstract class BaseTwoArrayHandler<T> implements TypeHandler<T[][]> {

    /** Java element class to SQL type name, as passed to {@code Connection#createArrayOf}. */
    private static final ConcurrentHashMap<Class<?>, String> STANDARD_MAPPING;

    static {
        STANDARD_MAPPING = new ConcurrentHashMap<>();
        STANDARD_MAPPING.put(BigDecimal.class, JdbcType.NUMERIC.name());
        STANDARD_MAPPING.put(BigInteger.class, JdbcType.BIGINT.name());
        STANDARD_MAPPING.put(boolean.class, JdbcType.BOOLEAN.name());
        STANDARD_MAPPING.put(Boolean.class, JdbcType.BOOLEAN.name());
        STANDARD_MAPPING.put(byte[].class, JdbcType.VARBINARY.name());
        STANDARD_MAPPING.put(byte.class, JdbcType.TINYINT.name());
        STANDARD_MAPPING.put(Byte.class, JdbcType.TINYINT.name());
        STANDARD_MAPPING.put(Calendar.class, JdbcType.TIMESTAMP.name());
        STANDARD_MAPPING.put(java.sql.Date.class, JdbcType.DATE.name());
        STANDARD_MAPPING.put(java.util.Date.class, JdbcType.TIMESTAMP.name());
        STANDARD_MAPPING.put(double.class, JdbcType.DOUBLE.name());
        STANDARD_MAPPING.put(Double.class, JdbcType.DOUBLE.name());
        STANDARD_MAPPING.put(float.class, JdbcType.REAL.name());
        STANDARD_MAPPING.put(Float.class, JdbcType.REAL.name());
        STANDARD_MAPPING.put(int.class, JdbcType.INTEGER.name());
        STANDARD_MAPPING.put(Integer.class, JdbcType.INTEGER.name());
        STANDARD_MAPPING.put(LocalDate.class, JdbcType.DATE.name());
        STANDARD_MAPPING.put(LocalDateTime.class, JdbcType.TIMESTAMP.name());
        STANDARD_MAPPING.put(LocalTime.class, JdbcType.TIME.name());
        STANDARD_MAPPING.put(long.class, JdbcType.BIGINT.name());
        STANDARD_MAPPING.put(Long.class, JdbcType.BIGINT.name());
        STANDARD_MAPPING.put(OffsetDateTime.class, JdbcType.TIMESTAMP_WITH_TIMEZONE.name());
        STANDARD_MAPPING.put(OffsetTime.class, JdbcType.TIME_WITH_TIMEZONE.name());
        // Primitive short was missing although every other primitive/wrapper pair is registered.
        STANDARD_MAPPING.put(short.class, JdbcType.SMALLINT.name());
        STANDARD_MAPPING.put(Short.class, JdbcType.SMALLINT.name());
        STANDARD_MAPPING.put(String.class, JdbcType.VARCHAR.name());
        STANDARD_MAPPING.put(Time.class, JdbcType.TIME.name());
        STANDARD_MAPPING.put(Timestamp.class, JdbcType.TIMESTAMP.name());
        STANDARD_MAPPING.put(URL.class, JdbcType.DATALINK.name());
    }

    /**
     * Returns the element type of the handled array.
     *
     * <p>The value is used both to resolve the SQL element type name when writing and as the
     * component type of the array allocated when reading.
     *
     * @return the element class {@code T}
     */
    protected abstract Class<T> getType();

    /**
     * Binds a two-dimensional array to a prepared-statement parameter.
     *
     * <p>A {@code null} parameter is bound as SQL NULL instead of being forwarded to
     * {@link Connection#createArrayOf(String, Object[])}. The temporary {@link Array} is
     * released with {@link Array#free()} once it has been handed to the statement.
     *
     * @param ps        the prepared statement to bind on
     * @param i         1-based index of the parameter
     * @param parameter the array to bind; may be {@code null}
     * @param jdbcType  JDBC type configured for the parameter; only used as the type code when
     *                  binding SQL NULL, may be {@code null}
     * @throws SQLException if the driver fails to create or bind the array
     */
    @Override
    public void setParameter(PreparedStatement ps, int i, T[][] parameter, JdbcType jdbcType) throws SQLException {
        if (parameter == null) {
            ps.setNull(i, jdbcType == null ? Types.ARRAY : jdbcType.TYPE_CODE);
            return;
        }
        Array array = ps.getConnection().createArrayOf(resolveTypeName(getType()), parameter);
        try {
            ps.setArray(i, array);
        } finally {
            // Release the driver-side resources held by the temporary Array.
            array.free();
        }
    }

    /**
     * Resolves the SQL type name for a Java element class.
     *
     * @param type the Java element class
     * @return the mapped {@link JdbcType} name, or {@code JAVA_OBJECT} when the class is not mapped
     */
    protected String resolveTypeName(Class<?> type) {
        return STANDARD_MAPPING.getOrDefault(type, JdbcType.JAVA_OBJECT.name());
    }

    /**
     * Reads the array stored in the named column.
     *
     * @param resultSet  the result set positioned on the current row
     * @param columnName label of the column to read
     * @return the converted array, or {@code null} when the column is NULL or the array is empty
     * @throws SQLException if the driver fails to read the column
     */
    @Override
    public T[][] getResult(ResultSet resultSet, String columnName) throws SQLException {
        return getArray(resultSet.getArray(columnName));
    }

    /**
     * Reads the array stored in the column at the given index.
     *
     * @param resultSet the result set positioned on the current row
     * @param i         1-based column index
     * @return the converted array, or {@code null} when the column is NULL or the array is empty
     * @throws SQLException if the driver fails to read the column
     */
    @Override
    public T[][] getResult(ResultSet resultSet, int i) throws SQLException {
        return getArray(resultSet.getArray(i));
    }

    /**
     * Reads the array held by a callable-statement output parameter.
     *
     * @param callableStatement the executed callable statement
     * @param i                 1-based parameter index
     * @return the converted array, or {@code null} when the value is NULL or the array is empty
     * @throws SQLException if the driver fails to read the parameter
     */
    @Override
    public T[][] getResult(CallableStatement callableStatement, int i) throws SQLException {
        return getArray(callableStatement.getArray(i));
    }

    /**
     * Converts a SQL array into a rectangular {@code T[][]}.
     *
     * <p>The result has one row per outer element and as many columns as the longest row;
     * {@code null} rows and the tail of shorter rows are left as {@code null} elements. The SQL
     * array is materialized exactly once and then released with {@link Array#free()}.
     *
     * @param array the SQL array; may be {@code null}
     * @return the converted array, or {@code null} when {@code array} is {@code null}, has no
     *         content, or has zero rows
     * @throws SQLException if the driver fails to materialize the array
     */
    private T[][] getArray(Array array) throws SQLException {
        if (array == null) {
            return null;
        }
        final Object[] rows;
        try {
            rows = (Object[]) array.getArray();
        } finally {
            array.free();
        }
        if (rows == null || rows.length == 0) {
            return null;
        }
        int rowCount = rows.length;
        int columnCount = maxRowLength(rows);
        // Safe: the array is created with getType() as its component type, i.e. it is a T[][].
        @SuppressWarnings("unchecked")
        T[][] result = (T[][]) java.lang.reflect.Array.newInstance(getType(), rowCount, columnCount);
        for (int r = 0; r < rowCount; r++) {
            Object[] columns = (Object[]) rows[r];
            if (columns != null && columns.length > 0) {
                System.arraycopy(columns, 0, result[r], 0, columns.length);
            }
        }
        return result;
    }

    /**
     * Returns the number of rows (outer elements) of a SQL array.
     *
     * <p>The array is not freed by this method; it stays owned by the caller.
     *
     * @param array the SQL array; may be {@code null}
     * @return the row count, or 0 when {@code array} or its content is {@code null}
     * @throws SQLException if the driver fails to materialize the array
     */
    public int getRowCount(Array array) throws SQLException {
        if (array == null) {
            return 0;
        }
        Object[] rows = (Object[]) array.getArray();
        return rows == null ? 0 : rows.length;
    }

    /**
     * Returns the length of the longest non-null row of a SQL array.
     *
     * <p>The array is not freed by this method; it stays owned by the caller.
     *
     * @param array the SQL array; may be {@code null}
     * @return the maximum row length, or 0 when {@code array} or its content is {@code null} or
     *         there is no non-null row
     * @throws SQLException if the driver fails to materialize the array
     */
    public int getColumnCount(Array array) throws SQLException {
        if (array == null) {
            return 0;
        }
        Object[] rows = (Object[]) array.getArray();
        return rows == null ? 0 : maxRowLength(rows);
    }

    /** Returns the length of the longest non-null row, or 0 when there is none. */
    private static int maxRowLength(Object[] rows) {
        int max = 0;
        for (Object row : rows) {
            if (row != null) {
                max = Math.max(max, ((Object[]) row).length);
            }
        }
        return max;
    }
}
